import numpy as np 
import pandas as pd
import subprocess
import os
import torchvision.models as models
import torchvision
import torch
import matplotlib.pyplot as plt
import scipy
import scipy.ndimage
import time
import random

from tqdm import tqdm

from copy import deepcopy

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

from torchsummary import summary

from xml.dom import minidom
from os.path import basename

from PIL import Image

from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, binary_erosion


# Root directory under which the dataset folders and the precomputed .npy
# arrays loaded by run() are located.
dataroot     = '/annon'
# ILSVRC CLS-LOC annotation and image folders, derived from dataroot.
# NOTE(review): none of these three is referenced in the visible part of this
# file — presumably used by the data-loading helpers; confirm.
xml_folder   = dataroot + '/data/ILSVRC/Annotations/CLS-LOC/train/'
train_folder = dataroot + '/data/ILSVRC/Data/CLS-LOC/train/'
val_folder   = dataroot + '/data/ILSVRC/Data/CLS-LOC/val/'
# Not referenced in the visible code; presumably input image side length,
# DataLoader worker count and batch size — TODO confirm against the loaders.
IMG_SIZE     = 224
WORKERS      = 1
BATCH_SIZE   = 256
# Presumably per-channel (RGB) normalisation mean / std applied to the input
# images — TODO confirm; not referenced in the visible code.
MEAN_NORM    = (0.485, 0.456, 0.406)
STD_NORM     = (0.229, 0.224, 0.225)
# Not referenced in the visible code.
FILTER_SIZE  = 7
# Device that run() moves each input batch to.
DEVICE       = 'cpu'


def main():
    """Script entry point: delegates to run()."""
    run()


class NetC(torch.nn.Module):
    """Wrap a classifier so one forward pass also exposes intermediate tensors.

    The wrapped network must provide ``avgpool`` and ``fc`` attributes. All of
    its child modules except the last two are chained into ``self.main``.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Chain every child module, then drop the trailing two.
        all_children = torch.nn.Sequential(*net.children())
        self.main = all_children[:-2]
        self.avgpool = net.avgpool
        self.linear = net.fc

    def forward(self, x):
        """Return ``(logits, pooled, feature_map)`` for the input batch ``x``.

        ``feature_map`` is the output of ``self.main``, ``pooled`` is that map
        after ``avgpool`` reshaped to ``(batch, channels)``, and ``logits`` is
        the output of the final linear layer applied to ``pooled``.
        """
        feature_map = self.main(x)
        pooled = self.avgpool(feature_map)
        batch, channels = pooled.shape[0], pooled.shape[1]
        # view() only succeeds when the pooled spatial extent is 1x1.
        pooled = pooled.view(batch, channels)
        return self.linear(pooled), pooled, feature_map


class netClassifier(torch.nn.Module):
    """Classification head of a wrapped network, applied to a feature map.

    ``netC`` must expose ``avgpool`` and ``linear`` attributes; both are
    reused, not copied.
    """

    def __init__(self, netC):
        super(netClassifier, self).__init__()
        self.net = netC

    def forward(self, C):
        """Return the logits obtained by pooling ``C`` and applying the linear layer.

        The pooled tensor is flattened per sample, so the channel width is
        taken from the input instead of being fixed to a single value.
        """
        x = self.net.avgpool(C)
        # Keep the batch dimension and flatten the rest; this matches the
        # former fixed-width reshape whenever the pooled width is 2048.
        x = x.view(x.shape[0], -1)
        logits = self.net.linear(x)
        return logits


class EnsembleTwin():
    """Per-class nearest-neighbour lookup over training rows.

    One ``KNeighborsClassifier`` is fitted per predicted class, using only the
    training rows assigned to that class. Neighbour positions returned by a
    per-class model are mapped back to row indices of the full training array.
    """

    def __init__(self):
        # twins[c] == [fitted classifier or None, row indices of class c]
        self.twins = list()

    def fit(self, X_train_c, train_preds, num_classes=1000):
        """Fit one neighbour model per class id in ``range(num_classes)``.

        Args:
            X_train_c: Training feature rows, indexable by a boolean mask.
            train_preds: 1-D array of class ids, one per row of ``X_train_c``.
            num_classes: Number of class ids to build models for.

        A class with no training rows gets a ``None`` placeholder so that
        ``self.twins`` stays indexable by class id. Calling ``fit`` again
        replaces the previously fitted models.
        """
        X_train_c = np.asarray(X_train_c)
        train_preds = np.asarray(train_preds)

        twins = list()
        for class_id in range(num_classes):
            mask = (train_preds == class_id)
            # Positions of this class's rows inside the full training array.
            mask_idxs = np.flatnonzero(mask).tolist()

            if not mask_idxs:
                # An estimator cannot be fitted on zero samples.
                twins.append([None, mask_idxs])
                continue

            twin = KNeighborsClassifier()
            twin.fit(X_train_c[mask], train_preds[mask])
            twins.append([twin, mask_idxs])

        self.twins = twins

    def predict(self, query_cont, query_pred, nns=1):
        """Return full-array row indices of the neighbours of ``query_cont``.

        Args:
            query_cont: Feature vector of the query.
            query_pred: Class id selecting which per-class model to search.
            nns: Number of neighbours requested.

        Returns:
            A list of row indices into the training array passed to ``fit``.
            It holds fewer than ``nns`` entries when the class has fewer
            training rows, and is empty when the class has none.
        """
        twin, real_idxs = self.twins[query_pred]
        if twin is None:
            return list()

        # Asking for more neighbours than fitted samples is an error in the
        # underlying estimator, so cap the request at the class size.
        k = min(nns, len(real_idxs))
        idxs = twin.kneighbors(X=[query_cont],
                               n_neighbors=k,
                               return_distance=False)[0]
        return [real_idxs[local_idx] for local_idx in idxs]


def evaluate_expt1(la, lb, df, tech, rand_layer):
    """Append one overlap row per query to ``df`` and return the new frame.

    For each position ``i`` the overlap is the number of distinct items shared
    by ``la[i]`` and ``lb[i]`` divided by ``len(la[i])``.

    Args:
        la: Sequence of neighbour-index lists (the denominator side).
        lb: Sequence of neighbour-index lists, at least as long as ``la``.
        df: DataFrame the new rows are appended to.
        tech: Value stored in the ``Technique`` column of every new row.
        rand_layer: Value stored in the ``LayerRand`` column of every new row.

    Returns:
        ``df`` followed by one row per entry of ``la``; ``df`` itself when
        ``la`` is empty.

    Raises:
        ZeroDivisionError: If an entry of ``la`` is empty.
        IndexError: If ``lb`` is shorter than ``la``.
    """
    rows = list()
    for i, a in enumerate(tqdm(la)):
        b = lb[i]
        overlap = len(set(a) & set(b)) / len(a)
        rows.append(pd.DataFrame({
            'Overlap': [overlap],
            'Technique': tech,
            # Recorded as a fixed 50 regardless of the actual list lengths.
            'NNs': 50,
            'LayerRand': rand_layer,
        }))

    if not rows:
        return df

    # Concatenate once: growing the frame inside the loop would re-copy all
    # previously appended rows on every iteration.
    return pd.concat([df, *rows])


def run():
    """Compare neighbours found by a global twin and by a per-class ensemble.

    Loads precomputed training contributions and predictions, fits a single
    k-NN model over all of them plus an ``EnsembleTwin`` (one model per
    class), then for each validation batch retrieves 50 neighbours from both
    and times each lookup. Only the first element of every batch is used as
    the query. Finally the overlap between the two neighbour lists is computed
    for the first 10 queries.

    Returns:
        The mean overlap over the evaluated queries (also printed).
    """
    # NOTE(review): imagenet_dataloaders is not defined in the visible part of
    # this file — confirm it is provided elsewhere before running.
    _, val_loader, _, _ = imagenet_dataloaders()
    FOLDER = 'twin_data'
    data_dir = dataroot + '/sonic_collect_twin_data_imagenet/' + FOLDER + '/'

    resnet = models.resnet50(pretrained=True).eval()
    # Keep the model on the same device the input batches are moved to.
    netC = NetC(resnet).to(DEVICE)

    # Only the arrays the comparison actually needs are loaded.
    X_train_cont = np.load(data_dir + 'X_train_cont.npy')
    y_train_pred = np.load(data_dir + 'y_train_pred.npy')

    # Make Twin Explainer
    twin = KNeighborsClassifier(n_neighbors=1)
    twin.fit(X_train_cont, y_train_pred)

    ensemble_twin = EnsembleTwin()
    ensemble_twin.fit(X_train_cont, y_train_pred)
    WEIGHTS = netC.linear.weight

    #### Collect the neighbours of every query from both explainers
    twin_nns = list()
    ensemble_nns = list()

    for query_idx, data in enumerate(tqdm(val_loader)):

        img, _ = data
        img = img.to(DEVICE)

        # One forward pass serves both lookups; no gradients are needed.
        with torch.no_grad():
            query_logits, query_x, _ = netC(img)
            query_pred = torch.argmax(query_logits, dim=1)[0].item()
            # Element-wise product of the predicted class's weight row with
            # the pooled activations of the first image in the batch.
            query_cont = WEIGHTS[query_pred] * query_x[0]
        query_cont = query_cont.detach().cpu().numpy()

        #### Twin
        start_time = time.time()
        nn_idxs = twin.kneighbors(X=[query_cont], n_neighbors=50, return_distance=False)[0]
        print("Time:", time.time() - start_time)
        twin_nns.append(nn_idxs.tolist())

        #### Ensemble Twin
        start_time = time.time()
        nn_idxs = ensemble_twin.predict(query_cont, query_pred, nns=50)
        print("Time:", time.time() - start_time)
        ensemble_nns.append(nn_idxs)

        # Stop after the query with index 5000 has been processed.
        if query_idx == 5000:
            break

    df = pd.DataFrame(columns=['Overlap', 'Technique', 'NNs', 'LayerRand'])
    df = evaluate_expt1(twin_nns[:10], ensemble_nns[:10], df, 'Compare', 'Normal')

    mean_overlap = df.Overlap.mean()
    print("Mean overlap:", mean_overlap)
    return mean_overlap


# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()



